import logging
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import numpy as np
from numpy import linalg

class ModelSVMSmooth:
    """Linear SVM trained with the squared (smoothed) hinge loss.

    For weights ``w`` and a set ``S`` of sample indices the objective is::

        0.5 * lam * ||w||^2 + 0.5 / |S| * sum_{i in S} max(0, 1 - y_i * <w, x_i>)^2

    Labels should be +1 or -1, and each entry of ``imgs`` a flat feature
    vector (list or 1-D array); all vectors must have the same length.
    """

    def __init__(self):
        # L2 regularization strength; solve_inner also uses it as the default step size.
        self.lam = 0.1
        # Per-sample margins y_i * <w, x_i> cached by the most recent gradient() call.
        self.inner_prod_times_label = None
        # Current weight vector (a NumPy array once set).
        self.w = None

    def get_weight_dimension(self, imgs, labels):
        """Return the feature dimension, taken from the first sample (all samples must match)."""
        return len(imgs[0])

    def get_init_weight(self, dim, rand_seed=None):
        """Set ``self.w`` to a zero vector of length ``dim`` and return it.

        ``rand_seed`` is accepted for interface compatibility but ignored,
        since the initialization is deterministic (all zeros).
        """
        self.w = np.zeros(dim)
        return self.w

    # ----- internal helpers -------------------------------------------------

    @staticmethod
    def _margins(imgs, labels, w, indices):
        """Return ``(x, y, margins)`` for the selected samples as float arrays.

        ``x`` stacks the chosen feature vectors row-wise, ``y`` holds their
        labels and ``margins[k] = y[k] * <w, x[k]>``. Converting to NumPy here
        means plain Python lists can be passed safely (``int * list`` would
        otherwise repeat the list instead of scaling it).

        Raises:
            ValueError: if ``indices`` is empty.
        """
        indices = list(indices)
        if not indices:
            raise ValueError('at least one sample is required')
        x = np.asarray([imgs[i] for i in indices], dtype=float)
        y = np.asarray([labels[i] for i in indices], dtype=float)
        return x, y, y * (x @ np.asarray(w, dtype=float))

    def _objective(self, w, margins):
        """Return ``0.5*lam*||w||^2 + 0.5*mean(max(0, 1 - margins)^2)``."""
        squared_hinge = np.maximum(0.0, 1.0 - np.asarray(margins, dtype=float)) ** 2
        # Python float division: an empty margin list still raises ZeroDivisionError.
        return (0.5 * self.lam * pow(linalg.norm(w), 2)
                + 0.5 * float(np.sum(squared_hinge)) / len(margins))

    @staticmethod
    def _proximal_term(w, w_t, mu):
        """Return the proximal penalty ``0.5 * mu * ||w - w_t||^2``."""
        diff = np.asarray(w, dtype=float) - np.asarray(w_t, dtype=float)
        return 0.5 * mu * pow(linalg.norm(diff), 2)

    # ----- model API --------------------------------------------------------

    def gradient(self, imgs, labels, w, sampleIndices):
        """Return the gradient of the objective at ``w`` over ``sampleIndices``.

        Labels should be 1 or -1. As a side effect, a copy of ``w`` is stored
        in ``self.w`` and the per-sample margins in
        ``self.inner_prod_times_label`` (a list), so that
        ``loss_from_prev_gradient_computation`` can reuse them.

        Raises:
            ValueError: if ``sampleIndices`` is empty.
        """
        # np.array (not asarray) so self.w never aliases the caller's array.
        self.w = np.array(w, dtype=float)
        x, y, margins = self._margins(imgs, labels, self.w, sampleIndices)
        self.inner_prod_times_label = margins.tolist()

        # d/dw of 0.5 * max(0, 1 - m)^2 is -(1 - m) * y * x for margin
        # violators (m < 1) and 0 otherwise; clipping the slack at 0 covers both.
        slack = np.maximum(0.0, 1.0 - margins)
        data_grad = -(x.T @ (y * slack)) / len(y)
        return self.lam * self.w + data_grad

    def loss(self, imgs, labels, w, sample_indices=None, w_t=None, mu=0.0):
        """Return the objective at ``w`` averaged over ``sample_indices``.

        Args:
            imgs: feature vectors, indexable by sample index.
            labels: +1/-1 labels aligned with ``imgs``.
            w: weight vector (list or array).
            sample_indices: indices to evaluate; defaults to every sample.
            w_t: reference weights; when given, the proximal term
                ``0.5 * mu * ||w - w_t||^2`` is added.
            mu: coefficient of the proximal term.

        Raises:
            ValueError: if there are no samples to evaluate.
        """
        if sample_indices is None:
            sample_indices = range(len(labels))
        _, _, margins = self._margins(imgs, labels, w, sample_indices)
        val = self._objective(w, margins)

        if w_t is not None:
            val += self._proximal_term(w, w_t, mu)
            logging.getLogger(__name__).debug(
                'Proximal term added in loss with mu: %s, w_t: %s, w: %s', mu, w_t, w)

        return val

    def loss_from_prev_gradient_computation(self, w_t=None, mu=0.0):
        """Return the objective using the weights and margins cached by gradient().

        This avoids recomputing inner products right after a gradient() call.
        If ``self.w`` has been changed since (e.g. via set_params), the cached
        margins are stale and the result mixes old margins with new weights.

        Args:
            w_t: reference weights for the optional proximal term.
            mu: coefficient of the proximal term.

        Raises:
            Exception: if no margins are cached or ``self.w`` is None.
        """
        if (self.inner_prod_times_label is None) or (self.w is None):
            raise Exception('No previous gradient computation exists')

        val = self._objective(self.w, self.inner_prod_times_label)

        if w_t is not None:
            val += self._proximal_term(self.w, w_t, mu)
            logging.getLogger(__name__).debug(
                'Proximal term added in loss_from_prev_gradient_computation '
                'with mu: %s, w_t: %s, w: %s', mu, w_t, self.w)

        return val

    def accuracy(self, imgs, labels, w):
        """Return the fraction of samples classified correctly (``y_i * <w, x_i> > 0``).

        Every sample is counted; index 0 is no longer skipped.

        Raises:
            ValueError: if ``labels`` is empty.
        """
        _, _, margins = self._margins(imgs, labels, w, range(len(labels)))
        return int(np.count_nonzero(margins > 0)) / len(labels)

    # --- Methods below were added by the author on top of the original model ---

    def get_params(self):
        """Return the current weight vector ``self.w``."""
        return self.w

    def set_params(self, params):
        """Replace the weights with a NumPy copy of ``params``."""
        self.w = np.array(params)

    def get_gradients(self, train_data, model_len):
        """Return ``(num_samples, gradient)`` over all of ``train_data`` at ``self.w``.

        ``train_data`` is a dict with ``'x'`` (features) and ``'y'`` (labels).
        ``model_len`` is accepted for interface compatibility but not used.
        """
        imgs, labels = train_data['x'], train_data['y']
        gradients = self.gradient(imgs, labels, self.w, range(len(labels)))
        return len(labels), gradients

    @property
    def size(self):
        """Number of parameters in the model (elements of ``self.w``)."""
        return self.w.size

    def solve_inner(self, train_data, num_epochs=1, batch_size=10, learning_rate=None):
        """Run mini-batch gradient descent on ``train_data`` starting from ``self.w``.

        Args:
            train_data: dict with ``'x'`` (features) and ``'y'`` (+1/-1 labels).
            num_epochs: number of passes over the data.
            batch_size: samples per step; the last batch may be smaller.
            learning_rate: step size. Defaults to ``self.lam``, which keeps the
                original behaviour of reusing the regularization strength as
                the step size.

        Returns:
            ``(weights, cost)`` where ``cost`` is ``self.lam * num_samples``.
        """
        imgs, labels = train_data['x'], train_data['y']
        num_samples = len(labels)
        step = self.lam if learning_rate is None else learning_rate

        for _ in range(num_epochs):
            for start in range(0, num_samples, batch_size):
                batch = range(start, min(start + batch_size, num_samples))
                grad = self.gradient(imgs, labels, self.w, batch)
                self.w = self.w - step * grad

        # NOTE(review): this "computation cost" scales with the regularization
        # strength rather than the work done; kept as-is for compatibility —
        # confirm what callers expect.
        return self.get_params(), self.lam * num_samples

    def test(self, test_data):
        """Evaluate ``self.w`` on ``test_data``.

        Returns:
            ``(tot_correct, num_samples, accuracy)``.

        Raises:
            ValueError: if ``test_data`` has no samples.
        """
        imgs, labels = test_data['x'], test_data['y']
        num_samples = len(labels)
        _, _, margins = self._margins(imgs, labels, self.w, range(num_samples))
        tot_correct = int(np.count_nonzero(margins > 0))
        return tot_correct, num_samples, tot_correct / num_samples

# Additional methods can be added as needed